import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Load the Iris dataset and keep only the last two feature columns (indices 2
# and 3: petal length and petal width, see iris.feature_names) so the data can
# be inspected as a 2-D scatter plot.
iris = datasets.load_iris()
x = iris.data[:, 2:]
y = iris.target

# Set to True to scatter-plot the two selected features, one colour per class,
# before the tree is fitted. Kept as a toggle instead of commented-out code.
SHOW_SCATTER = False

if SHOW_SCATTER:
    for label, name in enumerate(iris.target_names):
        # Boolean mask selects the samples belonging to this class.
        plt.scatter(x[y == label, 0], x[y == label, 1], label=name)
    plt.xlabel(iris.feature_names[2])
    plt.ylabel(iris.feature_names[3])
    plt.legend()
    plt.show()

# Fit a shallow decision tree on the two selected features and draw it.
# criterion="gini" is scikit-learn's default and is spelled out for clarity;
# criterion="entropy" (information gain) is the alternative splitting rule.
# random_state is fixed because scikit-learn randomly permutes the features at
# every split, so tied candidate splits could otherwise differ between runs.
tree = DecisionTreeClassifier(max_depth=2, criterion="gini", random_state=42)
tree.fit(x, y)

# Label the nodes with feature and class names instead of X[i] and raw indices.
plot_tree(
    tree,
    feature_names=iris.feature_names[2:],
    class_names=list(iris.target_names),
    filled=True,
)
plt.show()
